import glob
import logging
import os
import shutil
import time
from collections import deque
from os import path
from pathlib import Path
import numpy as np
import concurrent.futures
from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor
import threading
from queue import Queue
import asyncio
# import cProfile
import time

import torch
import torch.nn as nn
from torch import multiprocessing as mp
# os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
# os.environ["CUDA_VISIBLE_DEVICES"]=""
# print(torch.get_num_threads())
# torch.set_num_threads(torch.get_num_threads())
# from torch.profiler import profile, record_function, ProfilerActivity
from sacred import Experiment
from sacred.observers import (  # noqa
    FileStorageObserver,
    MongoObserver,
    QueuedMongoObserver,
    QueueObserver,
)
from torch.utils.tensorboard import SummaryWriter

import utils
from a2c import CommA2C, algorithm
from envs import make_vec_envs
from wrappers import FlattenObservation, RecordEpisodeStatistics, SquashDones, GlobalizeReward
from model import CommPolicy
from nn_modules import CommHeadAligner

import wandb

import rware # noqa
import lbforaging # noqa
from vmas import make_env as vmas_make_env

ex = Experiment(ingredients=[algorithm])


def _discard_captured_output(captured_output):
    """Sacred output filter: store a fixed placeholder instead of captured stdout."""
    return "Output capturing turned off."


ex.captured_out_filter = _discard_captured_output
ex.observers.append(FileStorageObserver("./results/sacred"))

# Root logger format: process id, one-letter level, timestamp, logger name, message.
_LOG_FORMAT = "(%(process)d) [%(levelname).1s] - (%(asctime)s) - %(name)s >> %(message)s"
_LOG_DATE_FORMAT = "%m/%d %H:%M:%S"
logging.basicConfig(level=logging.INFO, format=_LOG_FORMAT, datefmt=_LOG_DATE_FORMAT)

@ex.config
def config():
    # Sacred config scope: every local variable assigned below becomes a config
    # entry that captured functions (e.g. `main`, `evaluate`) receive by name.
    # Entries can be overridden from the command line or by a selected named
    # config (the YAML files under ./configs).
    env_name = None
    # NOTE: not read by true_main/evaluate; the env time limit actually comes
    # from env_configs["time_limit"].
    time_limit = None
    share_reward = False
    # Environment wrappers handed to make_vec_envs.
    wrappers = [
        RecordEpisodeStatistics,
        SquashDones,
        FlattenObservation
    ]
    # GlobalizeReward is appended only when share_reward is truthy (i.e. when it
    # has been overridden); the list is then frozen into a tuple.
    if(share_reward):
        wrappers.append(GlobalizeReward)
    wrappers = tuple(wrappers)
    dummy_vecenv = False

    # Total environment steps, split across algorithm['num_steps_schedule'] in
    # true_main (converted with int() there).
    num_env_steps = 100e6
    env_configs = {}
    # The communication interval (agents may only communicate every n steps) is
    # not set here; it is read from algorithm['comm_interval'].

    # Output directories. true_main replaces "{id}" with the seed for eval_dir
    # and save_dir; loss_dir is currently not used there.
    eval_dir = "./results/video/{id}"
    loss_dir = "./results/loss/{id}"
    save_dir = "./results/trained_models/{id}"

    # Intervals below are counted in update iterations within a schedule phase.
    log_interval = 2000
    save_interval = int(5e5)
    # save_interval = None
    eval_interval = int(1e4)
    # Number of parallel eval envs, and the minimum number of finished episodes
    # collected per evaluation.
    episodes_per_eval = 20


# Expose every YAML file under ./configs (relative to the working directory)
# as a Sacred named config, named after the file without its extension.
for conf in glob.glob("configs/*.yaml"):
    config_name = Path(conf).stem
    ex.add_named_config(config_name, conf)

def _squash_info(info, env_name, eval = False):
    """Collapse a sequence of info dicts into a single dict of averages.

    Falsy entries (e.g. empty dicts) are dropped first.

    * MarlGrid envs: only ``episode_reward`` is reported: each entry's reward
      array is summed, then the sums are averaged. When ``eval`` is true, the
      mean ``episode_length`` is reported as well.
    * All other envs: every key except ``TimeLimit.truncated`` is reduced per
      entry (``.mean()`` for MSMTC envs, ``.sum()`` otherwise). The reduced
      values are then averaged over the entries that contain that key.
    """
    episodes = [entry for entry in info if entry]
    all_keys = set([key for entry in episodes for key in entry.keys()])
    all_keys.discard("TimeLimit.truncated")

    if 'MarlGrid' in env_name:
        squashed = {
            'episode_reward': np.mean(np.array([entry['episode_reward'].sum() for entry in episodes])),
        }
        if eval:
            squashed['episode_length'] = np.mean(np.array([entry['episode_length'] for entry in episodes]))
        return squashed

    average_within_entry = 'MSMTC' in env_name
    squashed = {}
    for key in all_keys:
        values = [np.array(entry[key]) for entry in episodes if key in entry]
        per_entry = [v.mean() if average_within_entry else v.sum() for v in values]
        squashed[key] = np.mean(per_entry)
    return squashed

def _coordinate_messages(n_agents, n_messages, concat_messages):
    """Build each agent's incoming message from the messages of all *other* agents.

    Args:
        n_agents: number of receivers; receiver ``r`` gets every sender's
            message except the one at index ``r``.
        n_messages: per-sender messages, either as a sequence of tensors or as
            one tensor whose first dimension indexes the sender.
        concat_messages: if truthy, the other senders' messages are concatenated
            along dim 1; otherwise they are averaged element-wise.

    Returns:
        A tensor stacking the per-receiver results along a new dim 0. For
        sender messages of shape (P, C) and n senders, this is
        (n_agents, P, (n-1)*C) when concatenating and (n_agents, P, C) when
        averaging.
    """
    # A tensor input is split along its sender dimension into per-sender views.
    senders = n_messages.unbind(0) if torch.is_tensor(n_messages) else n_messages

    def _merge(others):
        if concat_messages:
            return torch.cat(others, dim=1)
        return torch.stack(others, dim=1).mean(dim=1)

    return torch.stack([
        _merge([message for sender, message in enumerate(senders) if sender != receiver])
        for receiver in range(n_agents)
    ])

def _prepare_n_act_tensors(n_agents, num_processes, hidden_size, num_comm_outputs=3):
    """Allocate zero-filled, shared-memory output buffers for per-agent acting.

    These buffers are intended for the (currently commented-out) threaded
    action-sampling path in ``true_main``, which hands them to
    ``agent.model.para_act`` as ``para_vars``.

    Args:
        n_agents: number of agents (first dimension of every buffer).
        num_processes: number of parallel environments (second dimension).
        hidden_size: width of the recurrent hidden-state buffer.
        num_comm_outputs: width of the message buffer. Defaults to 3, the value
            that used to be hard-coded; callers should pass
            ``algorithm['num_comm_outputs']``.

    Returns:
        ``(n_value, n_action, n_action_log_prob, n_messages,
        n_recurrent_hidden_states)`` with shapes ``(n_agents, num_processes, W)``
        for W = 1, 1, 1, ``num_comm_outputs`` and ``hidden_size``. Each buffer
        is a distinct tensor moved into shared memory.
    """
    def _shared_zeros(width):
        buffer = torch.zeros(n_agents, num_processes, width)
        buffer.share_memory_()
        return buffer

    return (
        _shared_zeros(1),
        _shared_zeros(1),
        _shared_zeros(1),
        _shared_zeros(num_comm_outputs),
        _shared_zeros(hidden_size),
    )

@ex.capture
def evaluate(
    agents,
    monitor_dir,
    episodes_per_eval,
    env_name,
    seed,
    wrappers,
    dummy_vecenv,
    time_limit,
    algorithm,
    env_configs,
    _log,
):
    """Roll out ``agents`` in fresh eval envs and return averaged episode stats.

    ``episodes_per_eval`` environments run in parallel until at least that many
    finished-episode infos have been collected. The infos are then averaged
    with ``_squash_info(..., eval=True)``. Actions are computed under
    ``torch.no_grad()``.

    ``monitor_dir``, ``time_limit`` and ``_log`` are accepted for Sacred and
    caller compatibility but are not used. The env time limit comes from
    ``env_configs["time_limit"]``.

    Returns:
        ``(episode_reward, success)`` for TrafficJunction envs,
        ``(episode_reward, episode_length)`` for MarlGrid envs, and
        ``episode_reward`` alone otherwise.
    """
    device = algorithm["device"]
    concat_messages = algorithm["concat_messages"]
    use_comm_sep_rnn = algorithm['use_comm_sep_rnn']
    is_marlgrid = 'MarlGrid' in env_name
    # For these envs an info only counts once it carries a finished episode's
    # 'episode_reward'; for all other envs any non-empty info counts.
    episode_end_only = any(
        tag in env_name
        for tag in ("predator_prey", "PredatorPrey", "TrafficJunction", "MSMTC", "MarlGrid")
    )

    def _agent_obs(obs, agent_id):
        # MarlGrid observations are (image, features) pairs, as unpacked in the
        # training rollout; every other env yields one tensor per agent. Used by
        # both act branches so the comm-separate-RNN path handles MarlGrid too.
        if is_marlgrid:
            return (torch.tensor(obs[agent_id][0]).float(), obs[agent_id][1].float())
        return obs[agent_id].float()

    eval_envs = make_vec_envs(
        env_name,
        env_configs,
        seed,
        dummy_vecenv,
        episodes_per_eval,
        env_configs["time_limit"],
        wrappers,
        device,
        env_properties=algorithm['env_properties'],
    )
    try:
        n_obs = eval_envs.reset()
        n_recurrent_hidden_states = [
            torch.zeros(episodes_per_eval, agent.model.recurrent_hidden_state_size, device=device)
            for agent in agents
        ]
        if use_comm_sep_rnn:
            n_comm_recurrent_hidden_states = [
                torch.zeros(episodes_per_eval, agent.model.recurrent_hidden_state_size, device=device)
                for agent in agents
            ]

        n_masks = torch.zeros(episodes_per_eval, 1, device=device)
        num_agents = len(eval_envs.observation_space)
        num_other_agents = num_agents - 1

        # All-zero messages in the packaged per-receiver layout. Agents receive
        # this buffer on the first step and on every non-communication step.
        msg_width = algorithm['num_comm_outputs'] * (num_other_agents if concat_messages else 1)
        msgs = torch.zeros((num_agents, episodes_per_eval, msg_width), device=device)

        all_infos = []
        comm_flag = 0
        comm_interval = algorithm['comm_interval']
        packaged_n_messages = None
        while len(all_infos) < episodes_per_eval:
            incoming = msgs if packaged_n_messages is None else packaged_n_messages
            with torch.no_grad():
                if use_comm_sep_rnn:
                    _, n_action, _, n_messages, n_recurrent_hidden_states, n_comm_recurrent_hidden_states = zip(
                        *[
                            agent.model.act(
                                _agent_obs(n_obs, agent.agent_id),
                                msg,
                                hxs.float(),
                                n_masks,
                                crnn_hxs=comm_hxs.float(),
                            )
                            for agent, msg, hxs, comm_hxs in zip(
                                agents, incoming, n_recurrent_hidden_states, n_comm_recurrent_hidden_states
                            )
                        ]
                    )
                else:
                    _, n_action, _, n_messages, n_recurrent_hidden_states, _ = zip(
                        *[
                            agent.model.act(_agent_obs(n_obs, agent.agent_id), msg, hxs.float(), n_masks)
                            for agent, msg, hxs in zip(agents, incoming, n_recurrent_hidden_states)
                        ]
                    )

            # Observe reward and next obs
            n_obs, _, done, infos = eval_envs.step(n_action)

            if comm_flag == 0:
                packaged_n_messages = _coordinate_messages(num_agents, n_messages, concat_messages)
            else:
                # Non-communication step: send the already-packaged zero buffer
                # unchanged (as the training loop does). Re-coordinating it would
                # concatenate packaged messages again, which gives the wrong
                # width when concat_messages is set with more than two agents.
                packaged_n_messages = msgs

            n_masks = torch.tensor(
                [[0.0] if done_ else [1.0] for done_ in done],
                dtype=torch.float32,
                device=device,
            )
            comm_flag = (comm_flag + 1) % comm_interval
            for info in infos:
                if episode_end_only:
                    if 'episode_reward' in info:
                        all_infos.append(info)
                elif info:
                    all_infos.append(info)
    finally:
        eval_envs.close()

    info = _squash_info(all_infos, env_name, eval=True)
    if "TrafficJunction" in env_name:
        return info['episode_reward'], info['success']
    if is_marlgrid:
        return info['episode_reward'], info['episode_length']
    return info['episode_reward']

#@profile
def true_main(
    _run,
    _log,
    num_env_steps,
    env_name,
    seed,
    algorithm,
    dummy_vecenv,
    time_limit,
    env_configs,
    wrappers,
    save_dir,
    eval_dir,
    loss_dir,
    log_interval,
    save_interval,
    eval_interval,
    share_reward,
):
    """Train one CommA2C agent per environment agent, logging to wandb.

    Training runs in phases, one per entry of ``algorithm['num_steps_schedule']``.
    ``num_env_steps`` is split evenly between the phases, and each phase performs
    as many updates as fit its share at its rollout length. Each update:

    1. collects a rollout of ``num_steps_schedule[phase]`` steps from
       ``algorithm['num_processes']`` parallel envs;
    2. runs every agent's ``update(agents)`` in parallel threads;
    3. applies all optimizer steps together afterwards.

    ``log_interval``, ``save_interval`` and ``eval_interval`` are counted in
    updates within the current phase (``j``, starting at 1). ``_run``,
    ``time_limit`` and ``loss_dir`` are accepted for interface compatibility but
    not used. The env time limit comes from ``env_configs["time_limit"]``.
    """
    utils.set_seed(seed)

    # Per-seed output directories ("{id}" is replaced by the seed).
    eval_dir = path.expanduser(eval_dir.format(id=str(seed)))
    os.makedirs(eval_dir, exist_ok=True)
    save_dir = path.expanduser(save_dir.format(id=str(seed)))
    os.makedirs(save_dir, exist_ok=True)
    utils.cleanup_log_dir(eval_dir)
    utils.cleanup_log_dir(save_dir)

    device = algorithm["device"]
    num_processes = algorithm["num_processes"]
    num_steps_schedule = algorithm['num_steps_schedule']
    use_comm_sep_rnn = algorithm['use_comm_sep_rnn']
    concat_messages = algorithm["concat_messages"]
    is_marlgrid = 'MarlGrid' in env_name

    envs = make_vec_envs(
        env_name,
        env_configs,
        seed,
        dummy_vecenv,
        num_processes,
        env_configs["time_limit"],
        wrappers,
        device,
        env_properties=algorithm['env_properties'],
    )
    # The envs must stay open across every phase of the schedule. They are
    # closed exactly once, after training finishes (or fails).
    try:
        num_other_agents = len(envs.observation_space) - 1
        num_agents = num_other_agents + 1
        comm_interval = algorithm['comm_interval']

        agents = [
            CommA2C(env_name, i, osp, asp, num_other_agents)
            for i, (osp, asp) in enumerate(zip(envs.observation_space, envs.action_space))
        ]
        for agent in agents:
            agent.model.share_memory()

        obs = envs.reset()

        # All-zero message buffer in the packaged per-receiver layout. It seeds
        # each agent's storage and is what agents receive on non-communication
        # steps.
        msg_width = algorithm['num_comm_outputs'] * (num_other_agents if concat_messages else 1)
        msg = torch.zeros((num_agents, num_processes, msg_width))

        for i, agent_obs in enumerate(obs):
            storage = agents[i].storage
            if is_marlgrid:
                # MarlGrid observations are (image, features) pairs.
                storage.img_obs[0].copy_(agent_obs[0])
                storage.df_obs[0].copy_(agent_obs[1])
            else:
                storage.obs[0].copy_(agent_obs)
            storage.msg.append(msg[i].clone())
            storage.to(device)

        # wandb run grouping.
        # NOTE(review): "_subproc" is appended when dummy_vecenv is True, which
        # looks inverted. It is left unchanged so existing wandb group names
        # stay stable.
        env_properties = algorithm['env_properties']
        run_suffix = ("_shared_reward" if share_reward else "") + ("_subproc" if dummy_vecenv else "")
        if env_properties is not None:
            group_name = "{}_{}_sr{}_rqs{}_{}".format(
                algorithm["algorithm_name"],
                env_name,
                env_properties['sensor_range'],
                env_properties['request_queue_size'],
                run_suffix,
            )
        else:
            group_name = algorithm["algorithm_name"] + "_" + env_name + run_suffix
        print("Algo to run: {}".format(group_name))
        wandb.init(project='cacl', entity='ssl_ec_marl', group=group_name, config=algorithm)
        wandb.config.update({"num_env_steps": num_env_steps, "seed": seed})

        # Equal share of env steps per phase, converted into an update count.
        env_steps_per_phase = int(num_env_steps) / len(num_steps_schedule)
        updates_schedule = [env_steps_per_phase // ss // num_processes for ss in num_steps_schedule]

        start = time.time()
        all_infos = deque(maxlen=12)
        # For these envs an info only counts once it carries a finished
        # episode's 'episode_reward'; for all other envs any non-empty info
        # counts. This matches the filter used by evaluate().
        episode_end_only = any(
            tag in env_name
            for tag in ("predator_prey", "PredatorPrey", "TrafficJunction", "MSMTC", "MarlGrid")
        )
        loss_keys = ("policy_loss", "value_loss", "aligner_loss", "pl_loss", "total_loss", "grad_norm")
        # Per-agent-averaged losses, summed over the updates since the last log.
        loss_dict = dict.fromkeys(loss_keys, 0.0)
        comm_flag = 0
        total_update_steps = 0

        for u_idx, update_step in enumerate(updates_schedule):
            num_steps = num_steps_schedule[u_idx]
            # Env steps consumed by the (fully completed) earlier phases.
            steps_before_phase = sum(
                updates_schedule[k] * num_processes * num_steps_schedule[k] for k in range(u_idx)
            )
            for j in range(1, int(update_step) + 1):
                total_update_steps += 1

                # ---- Rollout collection ----
                for step in range(num_steps):
                    # act() is deliberately NOT wrapped in torch.no_grad(): the
                    # messages keep their graph so the update can backpropagate
                    # through them.
                    act_outputs = []
                    for agent in agents:
                        storage = agent.storage
                        if is_marlgrid:
                            agent_obs = (storage.img_obs[step].clone().detach(), storage.df_obs[step].clone().detach())
                        else:
                            agent_obs = storage.obs[step].clone().detach()
                        extra_kwargs = {}
                        if use_comm_sep_rnn:
                            extra_kwargs['crnn_hxs'] = storage.comm_recurrent_hidden_states[step].clone().detach()
                        act_outputs.append(
                            agent.model.act(
                                agent_obs,
                                storage.msg[step].clone().detach(),
                                storage.recurrent_hidden_states[step].clone().detach(),
                                storage.masks[step].clone().detach(),
                                comm_partial_with_grad=True,
                                **extra_kwargs,
                            )
                        )
                    (
                        n_value,
                        n_action,
                        n_action_log_prob,
                        n_messages,
                        n_recurrent_hidden_states,
                        n_comm_recurrent_hidden_states,
                    ) = zip(*act_outputs)

                    # Observe reward and next obs
                    obs, reward, done, infos = envs.step(n_action)

                    masks = torch.FloatTensor([[0.0] if done_ else [1.0] for done_ in done])
                    bad_masks = torch.FloatTensor(
                        [[0.0] if info.get("TimeLimit.truncated", False) else [1.0] for info in infos]
                    )

                    # Fresh messages only every comm_interval steps; otherwise the zero buffer.
                    if comm_flag == 0:
                        packaged_n_messages = _coordinate_messages(num_agents, n_messages, concat_messages)
                    else:
                        packaged_n_messages = msg

                    for i, agent in enumerate(agents):
                        agent.storage.insert(
                            obs[i],
                            packaged_n_messages[i],
                            n_recurrent_hidden_states[i].detach(),
                            n_action[i],
                            n_action_log_prob[i].detach(),
                            n_value[i].detach(),
                            reward[:, i].unsqueeze(1) / algorithm['reward_scale'],
                            masks,
                            bad_masks,
                            n_comm_recurrent_hidden_states[i].detach() if use_comm_sep_rnn else None,
                        )
                        agent.storage.to(device)

                    for info in infos:
                        if episode_end_only:
                            if 'episode_reward' in info:
                                all_infos.append(info)
                        elif info:
                            all_infos.append(info)

                    comm_flag = (comm_flag + 1) % comm_interval

                # ---- Update ----
                for agent in agents:
                    agent.compute_returns()
                for agent in agents:
                    agent.optimizer.zero_grad()

                # Every agent's update receives all agents, so losses are
                # computed in parallel threads first and the optimizer steps are
                # applied only afterwards (stepping one model early would affect
                # the other agents' updates).
                threads = []
                for agent in agents:
                    t = utils.ThreadWithReturnValue(target=agent.update, args=(agents,))
                    t.start()
                    threads.append(t)
                per_agent_losses = [t.join() for t in threads]

                for k in loss_dict:
                    loss_dict[k] += sum((losses[k] for losses in per_agent_losses), 0.0) / float(len(per_agent_losses))

                for agent in agents:
                    if algorithm['use_clipping']:
                        nn.utils.clip_grad_norm_(agent.model.parameters(), algorithm['max_grad_norm'])
                    agent.optimizer.step()

                for agent in agents:
                    agent.storage.after_update()

                # ---- Logging ----
                if j % log_interval == 0 and len(all_infos) > 1:
                    squashed = _squash_info(all_infos, env_name)
                    # Completed earlier phases plus the j updates done in this
                    # phase (j starts at 1, so j updates have been performed).
                    total_num_steps = int(steps_before_phase + j * num_processes * num_steps)
                    end = time.time()
                    _log.info(
                        "Updates {}, num timesteps {}, FPS {}".format(
                            total_update_steps, total_num_steps, int(total_num_steps / (end - start))
                        )
                    )
                    _log.info(
                        "Last {} training episodes mean reward {:.3f}".format(
                            len(all_infos), squashed['episode_reward'].sum()
                        )
                    )
                    wandb.log({
                        'episode_reward': squashed['episode_reward'].sum(),
                        'total_num_steps': total_num_steps,
                        'updates': total_update_steps,
                        'policy_loss': loss_dict['policy_loss'],
                        "value_loss": loss_dict['value_loss'],
                        "aligner_loss": loss_dict['aligner_loss'],
                        "pl_loss": loss_dict['pl_loss'],
                        "grad_norm": loss_dict['grad_norm'],
                    })
                    loss_dict = dict.fromkeys(loss_keys, 0.0)
                    all_infos.clear()

                # ---- Checkpointing (also at the last update of each phase) ----
                if save_interval is not None and (j % save_interval == 0 or j == update_step):
                    for agent in agents:
                        save_at_wandb = os.path.join(
                            wandb.run.dir, group_name + "_agent{}_seed{}".format(agent.agent_id, seed)
                        )
                        os.makedirs(save_at_wandb, exist_ok=True)
                        # NOTE(review): passes the phase's total update count
                        # rather than j; confirm this is what agent.save expects.
                        agent.save(save_at_wandb, update_step)

                # ---- Evaluation (also at the last update of each phase) ----
                if eval_interval is not None and (j % eval_interval == 0 or j == update_step):
                    monitor_dir = os.path.join(eval_dir, "u{}".format(j))
                    if "TrafficJunction" in env_name:
                        eval_reward, eval_success = evaluate(agents, monitor_dir, env_configs=env_configs)
                        wandb.log({'eval_updates': total_update_steps, 'eval_reward': eval_reward, 'eval_success': eval_success})
                    elif is_marlgrid:
                        eval_reward, eval_episode_length = evaluate(agents, monitor_dir, env_configs=env_configs)
                        wandb.log({'eval_updates': total_update_steps, 'eval_reward': eval_reward, 'eval_episode_length': eval_episode_length})
                    else:
                        eval_reward = evaluate(agents, monitor_dir, env_configs=env_configs)
                        wandb.log({'eval_updates': total_update_steps, 'eval_reward': eval_reward})
    finally:
        envs.close()

@ex.automain
def main(
    _run,
    _log,
    num_env_steps,
    env_name,
    seed,
    algorithm,
    dummy_vecenv,
    time_limit,
    env_configs,
    wrappers,
    save_dir,
    eval_dir,
    loss_dir,
    log_interval,
    save_interval,
    eval_interval,
    share_reward,
):
    """Sacred entry point: forward the injected config values to ``true_main``."""
    true_main(
        _run=_run,
        _log=_log,
        num_env_steps=num_env_steps,
        env_name=env_name,
        seed=seed,
        algorithm=algorithm,
        dummy_vecenv=dummy_vecenv,
        time_limit=time_limit,
        env_configs=env_configs,
        wrappers=wrappers,
        save_dir=save_dir,
        eval_dir=eval_dir,
        loss_dir=loss_dir,
        log_interval=log_interval,
        save_interval=save_interval,
        eval_interval=eval_interval,
        share_reward=share_reward,
    )


# if __name__ == '__main__':
#     main()
    # import cProfile, pstats
    # profiler = cProfile.Profile()
    # profiler.enable()
    # main()
    # profiler.disable()
    # stats = pstats.Stats(profiler).sort_stats('cumtime')
    # stats.print_stats()
